'''
一个微缩版的DenseNet,因为学习使用,本人用的是普通的CPU。所以这里搭建了一个微缩版的Densenet，结构都有，只是层数和参数比较少
使用方法
用ImgToDataset.py生成数据集和测试集后可以直接运行使用
运行后手工调用train()
训练时每个 epoch 结束后自动测试测试集，测试精度超过 0.9 时保存训练结果（MyDenseNetSaver/saver.ckpt-<epoch>）
调用训练结果时用restore()，参数填入要读取的 checkpoint 后缀，例如 restore('-99')

注：训练时可以再增加一个全连接层并加入完整的 center loss，以实现 one-shot 识别。
目前的损失中只加入了一个简化版的批内 center loss（见 center_loss()：以每批前 centerCount 个样本特征的均值为中心），完整实现待后续增大数据集后再做。
'''
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import VoiceToImg

# ---- Input / output dimensions ----
# Size of one input image, and its flattened element count.
n_width = 119
n_height = 129
n_channel = 3
n_input = n_width * n_height * n_channel

# Number of output classes (width of the one-hot labels and of the logits).
n_output = 6


# ---- DenseNet architecture ----
# Channels produced by every conv layer inside a dense block; each layer
# concatenates this many new channels onto the block's feature map.
k = 8

# Channels produced by the first (stem) convolution.
k0 = 32

# Number of layers in each dense block.
bl = 8

# Number of dense blocks.
block_count = 6

# ---- Training schedule ----
# One call to train() processes training_epochs * total_batch * batch_size
# images (100 * 20 * 20 = 40000 with the values below).
training_epochs = 100

# Optimizer steps per epoch; test accuracy is measured once per epoch, i.e.
# after every total_batch * batch_size training images.
total_batch = 20

# Samples per training batch.
batch_size = 20

# Number of leading samples of each training batch used by the center loss
# (passed to VoiceToImg.GetBatchData and used to slice the feature vectors).
centerCount = 10

# Number of test images drawn for each accuracy evaluation.
test_batch_size = 200


# Length of the feature (embedding) vector produced before the classifier.
feacher_len = 32
# Parameter dicts for all dense blocks: blocks[i] maps a variable name to its
# tf.Variable. Entries 0..block_count-1 are filled by the loop below; the two
# output-layer dicts are appended after the transition-conv parameters.
blocks = []

# Create the dense-block parameters. For block i, layer j:
#   con_i_j_1 / bia_i_j_1 : 1x1 conv from the layer's input channels to k
#   con_i_j_3 / bia_i_j_3 : 3x3 conv, k -> k
# Layer j of block i sees k0 + i*bl*k/2 + j*k input channels: the block starts
# from the previous transition's output (k0 + i*bl*k/2 channels) and every
# earlier layer of the block has concatenated k more.
# All weights and biases are drawn from a normal distribution with stddev 0.1.
# NOTE: the loop variables (i, j, block, item_name, feature_count) remain as
# module-level globals after this runs.
for i in range(block_count):
    block = {}
    for j in range(bl):
        item_name = 'con_' + str(i) + '_' + str(j) + '_1'
        feature_count = (int)(k0 + j * k + (i * bl * k) / 2)
        block[item_name] = tf.Variable(tf.random_normal([1, 1, feature_count, k], stddev = 0.1),
                dtype=tf.float32, name = item_name)

        item_name = 'bia_' + str(i) + '_' + str(j) + '_1'
        block[item_name] = tf.Variable(tf.random_normal([k], stddev = 0.1),
                dtype=tf.float32, name = item_name)

        item_name = 'con_' + str(i) + '_' + str(j) + '_3'
        block[item_name] = tf.Variable(tf.random_normal([3, 3, k, k], stddev = 0.1),
                dtype=tf.float32, name = item_name)

        item_name = 'bia_' + str(i) + '_' + str(j) + '_3'
        block[item_name] = tf.Variable(tf.random_normal([k], stddev = 0.1),
                dtype=tf.float32, name = item_name)
    blocks.append(block)
# Stem and transition conv parameters.
#   wc0 / bc0         : 3x3 stem conv, n_channel -> k0 channels.
#   wc{i+1} / bc{i+1} : 1x1 transition conv applied after dense block i. Its
#                       input is the block output (k0 + i*bl*k/2 + bl*k
#                       channels) and its output has k0 + (i+1)*bl*k/2
#                       channels, so each block+transition stage adds a net
#                       bl*k/2 channels.
weights = {
        'wc0' : tf.Variable(tf.random_normal([3, 3, n_channel, k0], stddev = 0.1),
                            dtype=tf.float32, name = 'wc0')
    }
biases = {
        'bc0' : tf.Variable(tf.random_normal([k0], stddev = 0.1),
                            dtype=tf.float32, name = 'bc0')
        }
for i in range(block_count):
    wc_name = 'wc' + str(i + 1)
    # Channels leaving dense block i (its input plus bl layers of k each).
    feature_count = (int)(k0 + i * bl * k / 2 + bl * k)
    # Channels entering dense block i + 1.
    out_count = (int)(k0 + (i + 1) * bl * k / 2)
    weights[wc_name] = tf.Variable(tf.random_normal([1, 1, feature_count, out_count], stddev = 0.1),
                          dtype=tf.float32, name = wc_name)
    bc_name = 'bc' + str(i + 1)
    biases[bc_name] = tf.Variable(tf.random_normal([out_count], stddev = 0.1),
                        dtype=tf.float32, name = bc_name)

# Output layer 1 (stored as blocks[block_count]): 1x1 conv projecting the last
# transition output (k0 + block_count*bl*k/2 channels) to the feacher_len-long
# feature vector. Variables are named con_<block_count+1> / bia_<block_count+1>.
block = {}
item_name = 'con_' + str(block_count + 1)
feature_count = (int)(k0 + block_count * bl * k / 2)
block[item_name] = tf.Variable(tf.random_normal([1, 1, feature_count, feacher_len], stddev = 0.1),
        dtype=tf.float32, name = item_name)

item_name = 'bia_' +  str(block_count + 1)
block[item_name] = tf.Variable(tf.random_normal([feacher_len], stddev = 0.1),
        dtype=tf.float32, name = item_name)
blocks.append(block)

# Output layer 2 (stored as blocks[block_count + 1]): 1x1 conv from the feature
# vector to n_output logits, which the loss feeds to softmax cross-entropy.
# Variables are named con_<block_count+2> / bia_<block_count+2>.
block = {}
item_name = 'con_' + str(block_count + 2)
block[item_name] = tf.Variable(tf.random_normal([1, 1, feacher_len, n_output], stddev = 0.1),
        dtype=tf.float32, name = item_name)

item_name = 'bia_' +  str(block_count + 2)
block[item_name] = tf.Variable(tf.random_normal([n_output], stddev = 0.1),
        dtype=tf.float32, name = item_name)
blocks.append(block)
          
#dense块运算
def DenseBlockRun(inputs, lvl, istraining):
    """Apply dense block number ``lvl`` to ``inputs``.

    The block has ``bl`` layers. Each layer runs a 1x1 conv + bias + ReLU
    (k output channels) followed by a 3x3 conv + bias + ReLU (k output
    channels). The layer's result is concatenated onto the running feature
    map along axis 3, so the returned tensor has ``bl * k`` more channels
    than ``inputs``.

    Args:
        inputs: 4-D feature map fed to the block (channels on axis 3).
        lvl: index into ``blocks`` selecting this block's parameters.
        istraining: accepted for the (currently disabled) batch-normalization
            layers; not used at present.

    Returns:
        ``inputs`` concatenated with the outputs of all ``bl`` layers.
    """
    params = blocks[lvl]
    features = inputs
    for layer in range(bl):
        tag = '{}_{}'.format(lvl, layer)

        # 1x1 conv: current channel count -> k.
        hidden = tf.nn.conv2d(features, params['con_' + tag + '_1'],
                              strides=[1, 1, 1, 1], padding='SAME')
        hidden = tf.nn.relu(tf.nn.bias_add(hidden, params['bia_' + tag + '_1']))

        # 3x3 conv: k -> k.
        hidden = tf.nn.conv2d(hidden, params['con_' + tag + '_3'],
                              strides=[1, 1, 1, 1], padding='SAME')
        hidden = tf.nn.relu(tf.nn.bias_add(hidden, params['bia_' + tag + '_3']))

        # Dense connectivity: keep every earlier feature map.
        features = tf.concat([features, hidden], 3)
    return features

#前向传播函数
def MyDenseNetRun(inputs, istraining = False):
    """Build the network's forward graph.

    Pipeline:
      1. Stem: 3x3 conv to k0 channels, ReLU, 2x2 max-pool (stride 2, SAME).
      2. For each of ``block_count`` stages: DenseBlockRun, a 1x1 transition
         conv + ReLU, then a 2x2 max-pool. For a 119x129 input the spatial size
         goes 60x65 -> 30x33 -> 15x17 -> 8x9 -> 4x5 -> 2x3 -> 1x2, while the
         channel count after each stage is 64, 96, ..., 224.
      3. Feature projection: 1x1 conv to ``feacher_len`` channels (no ReLU).
         Its spatial mean is the returned feature vector.
      4. Classifier: 1x1 conv from the (un-pooled) projection to ``n_output``
         channels, then a spatial mean. Because a 1x1 conv plus bias is affine,
         this equals applying the classifier to the pooled feature vector.

    Args:
        inputs: 4-D image batch (channels on axis 3).
        istraining: forwarded to DenseBlockRun.

    Returns:
        [logits, feacher]: logits of shape (N, n_output) and the feature
        vector of shape (N, 1, 1, feacher_len).
    """
    def conv_bias(tensor, kernel, bias):
        # Stride-1 SAME convolution followed by a bias add (no activation).
        return tf.nn.bias_add(
            tf.nn.conv2d(tensor, kernel, strides=[1, 1, 1, 1], padding='SAME'), bias)

    def halve(tensor):
        # 2x2 max-pool with stride 2; SAME padding rounds odd sizes up.
        return tf.nn.max_pool(tensor, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

    # Stem.
    net = halve(tf.nn.relu(conv_bias(inputs, weights['wc0'], biases['bc0'])))

    # Dense blocks, each followed by transition conv + pool.
    for stage in range(1, block_count + 1):
        net = DenseBlockRun(net, stage - 1, istraining)
        key = str(stage)
        net = halve(tf.nn.relu(conv_bias(net, weights['wc' + key], biases['bc' + key])))

    # Feature projection -> feature vector (spatial mean, dims kept).
    feat_key = str(block_count + 1)
    feat_params = blocks[block_count]
    net = conv_bias(net, feat_params['con_' + feat_key], feat_params['bia_' + feat_key])
    feacher = tf.reduce_mean(net, [1, 2], keep_dims=True)

    # Classifier on the projection map, then spatial mean -> (N, n_output).
    cls_key = str(block_count + 2)
    cls_params = blocks[block_count + 1]
    logits = tf.reduce_mean(
        conv_bias(net, cls_params['con_' + cls_key], cls_params['bia_' + cls_key]),
        [1, 2])
    return [logits, feacher]

# Graph inputs:
#   x       - 4-D float32 image batch, every dimension left dynamic.
#   y       - float32 one-hot labels, shape (N, n_output).
#   isTrain - boolean flag passed to MyDenseNetRun as ``istraining``.
x = tf.placeholder(tf.float32, shape=[None, None, None, None], name='x')
y = tf.placeholder(tf.float32, shape=[None, n_output], name='y')
isTrain = tf.placeholder(tf.bool, name='isTrain')
# pred == [logits (N, n_output), feature vector (N, 1, 1, feacher_len)]
pred = MyDenseNetRun(x, isTrain)

def center_loss(feachers):
    """Simplified center loss on the first ``centerCount`` samples of a batch.

    The feature tensor is flattened to (-1, feacher_len), and its first
    ``centerCount`` rows are taken. The loss is the mean squared deviation of
    those rows from their own mean, averaged over rows and feature dimensions.

    NOTE(review): this is only meaningful if the first centerCount samples of
    each training batch share a class, presumably what
    VoiceToImg.GetBatchData(batch_size, centerCount) arranges. Confirm.

    Args:
        feachers: feature tensor whose trailing size is feacher_len.

    Returns:
        Scalar loss tensor.
    """
    flat = tf.reshape(feachers, shape=[-1, feacher_len])
    group = flat[0:centerCount, :]
    deviation = group - tf.reduce_mean(group, axis=0)
    return tf.reduce_mean(tf.square(deviation))
    
# Total loss: softmax cross-entropy on the logits (pred[0]) plus the simplified
# center loss on the feature vectors (pred[1]).
cost_dis = center_loss(feachers=pred[1])
cost_num = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred[0], labels=y))

cost = cost_num + cost_dis

net_optimizer = tf.train.AdamOptimizer(learning_rate = 0.001, name = 'Adam')
train_op = net_optimizer.minimize(cost)
# Variant to use if the batch-normalization layers are re-enabled: run the ops
# in UPDATE_OPS (BN moving-average updates) before each optimizer step.
#net_optimizer = tf.train.AdamOptimizer(learning_rate = 0.001, name = 'Adam')
#update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
#with tf.control_dependencies(update_ops):
#    train_op = net_optimizer.minimize(cost)

# Accuracy: fraction of samples whose arg-max logit equals the arg-max of y.
corr = tf.equal(tf.argmax(pred[0], 1), tf.argmax(y, 1))
accr = tf.reduce_mean(tf.cast(corr, tf.float32))

#LAUNCH THE GRAPH
# The initializer and the Saver are created after minimize() so that they
# also cover any variables the optimizer added to the graph.
init = tf.global_variables_initializer()
# Saver over every global variable (network weights and optimizer state).
global_vars = tf.global_variables()
saver = tf.train.Saver(global_vars)
sess = tf.Session()

# NOTE(review): no queue runners are visible in this graph, so this
# coordinator/thread setup appears to be a no-op. Confirm before removing it.
coordinator = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)
sess.run(init)

def train():
    """Train the network, evaluating (and possibly checkpointing) every epoch.

    For each of ``training_epochs`` epochs:
      * runs ``total_batch`` optimizer steps, each on a batch from
        ``VoiceToImg.GetBatchData(batch_size, centerCount)``, printing the loss
        and the global step number;
      * measures test accuracy with ``testOnce()``;
      * from the second epoch on, records (avg, min, max) loss and test
        accuracy and plots both curves with matplotlib;
      * saves a checkpoint ``MyDenseNetSaver/saver.ckpt-<epoch>`` when the test
        accuracy exceeds 0.9, creating the directory if needed.

    Raises:
        ValueError: if a batch contains a label outside [0, n_output).
    """
    plotView = []
    plotAcc = []
    for epoch in range(training_epochs):
        avg_cost = 0
        # +/-inf sentinels so the first observed loss always replaces them.
        max_cost = float('-inf')
        min_cost = float('inf')
        #ITERATION
        for i in range(total_batch):
            label_batch_v, image_batch_v = VoiceToImg.GetBatchData(batch_size, centerCount)
            # Build the one-hot targets with NumPy on the host. Creating TF ops
            # here would add new nodes to the graph on every step, so graph
            # size, memory use and step time would keep growing.
            # Assumes label_batch_v holds one integer class id per sample
            # (e.g. shape (batch, 1)). Confirm against VoiceToImg.
            labels = np.asarray(label_batch_v).astype(np.int64).reshape(-1)
            if labels.size and (labels.min() < 0 or labels.max() >= n_output):
                raise ValueError('labels must lie in [0, %d), got %r' % (n_output, labels))
            onehot = np.zeros((labels.size, n_output), dtype=np.float32)
            onehot[np.arange(labels.size), labels] = 1.0

            cost_once, _ = sess.run([cost, train_op], feed_dict = {
                    x : image_batch_v, y : onehot, isTrain : True})
            # Global step number across all epochs.
            print('traningonce:loss is', str(cost_once), '---', str(epoch * total_batch + i))

            max_cost = max(max_cost, cost_once)
            min_cost = min(min_cost, cost_once)
            avg_cost += cost_once
        avg_cost = avg_cost / total_batch
        #DISPLAY
        test_acc = testOnce()
        # The first epoch is neither recorded nor plotted.
        if epoch > 0:
            plotView.append([avg_cost, min_cost, max_cost])
            plotAcc.append(test_acc)
            plt.plot(plotView)
            plt.xlabel("epoch")
            # Series order matches the appended rows: avg, min, max.
            plt.ylabel("avg loss&min loss&max loss")
            plt.show()

            plt.plot(plotAcc)
            plt.xlabel("epoch")
            plt.ylabel("test accuracy")
            plt.show()
        print("----------Epoch:%03d/%03d cost:%.9f min:%.9f max:%.9f" %
              (epoch, training_epochs, avg_cost, min_cost, max_cost))
        if test_acc > 0.9:
            checkpoint_prefix = os.path.abspath('MyDenseNetSaver/saver.ckpt')
            # tf.train.Saver.save refuses to save into a missing directory.
            os.makedirs(os.path.dirname(checkpoint_prefix), exist_ok=True)
            saver.save(sess, checkpoint_prefix, global_step = epoch)
    print("OPTIMIZATION FINISHED")
    
def testOnce():
    """Measure accuracy on one batch of ``test_batch_size`` test samples.

    The batch comes from ``VoiceToImg.GetBatchData(test_batch_size, -1, True)``.
    Presumably -1 disables center-class grouping and True selects the test
    split; confirm in VoiceToImg. ``accr`` is evaluated with ``isTrain``
    False, and the accuracy is printed and returned.

    Returns:
        The batch accuracy as a float in [0, 1].

    Raises:
        ValueError: if a label lies outside [0, n_output).
    """
    label_batch_v, image_batch_v = VoiceToImg.GetBatchData(test_batch_size, -1, True)
    # One-hot targets are built with NumPy so that repeated calls (one per
    # epoch) do not add new ops to the TF graph.
    # Assumes one integer class id per sample (e.g. shape (batch, 1)).
    labels = np.asarray(label_batch_v).astype(np.int64).reshape(-1)
    if labels.size and (labels.min() < 0 or labels.max() >= n_output):
        raise ValueError('labels must lie in [0, %d), got %r' % (n_output, labels))
    onehot = np.zeros((labels.size, n_output), dtype=np.float32)
    onehot[np.arange(labels.size), labels] = 1.0

    test_acc = sess.run(accr, feed_dict = {
            x : image_batch_v, y : onehot, isTrain : False})
    print("TEST ACCURACY: %.3f" % (test_acc))
    return test_acc

# True once RecognizeData() has restored a checkpoint into ``sess``.
Inited = False

def RecognizeData(data, select_ctpk=None):
    """Run the trained network on ``data`` and return ``sess.run(pred)``.

    The first call loads weights into ``sess``:
      * if ``select_ctpk`` is given, the checkpoint with that suffix
        (e.g. '-99', as accepted by ``restore``);
      * otherwise, the most recent checkpoint recorded in ``MyDenseNetSaver/``.
    Later calls reuse the restored session, and ``select_ctpk`` is then
    ignored.

    Args:
        data: image batch fed to the ``x`` placeholder.
        select_ctpk: optional checkpoint suffix; None means "latest".

    Returns:
        [logits, feature]: the evaluated ``pred`` list.

    Raises:
        FileNotFoundError: if ``select_ctpk`` is None and no checkpoint has
            been recorded in ``MyDenseNetSaver/``.
    """
    global Inited
    if not Inited:
        if select_ctpk is None:
            # train() saves with global_step=epoch, i.e. at most
            # training_epochs - 1, so a fixed suffix such as '-100' may never
            # exist. Use whichever checkpoint was saved last.
            latest = tf.train.latest_checkpoint(os.path.abspath('MyDenseNetSaver'))
            if latest is None:
                raise FileNotFoundError('no checkpoint found in MyDenseNetSaver/')
            saver.restore(sess, latest)
        else:
            restore(select_ctpk)
        # Remember the restore so later calls don't reload weights from disk.
        Inited = True
    res = sess.run(pred, feed_dict = {x : data, isTrain : False})
    return res
    
def restore(select_ctpk):
    """Load the checkpoint ``MyDenseNetSaver/saver.ckpt<select_ctpk>`` into ``sess``.

    Args:
        select_ctpk: suffix appended to the checkpoint prefix, e.g. '-99' for
            the checkpoint that ``train`` saved at epoch 99.
    """
    checkpoint_path = os.path.abspath('MyDenseNetSaver/saver.ckpt' + select_ctpk)
    saver.restore(sess, checkpoint_path)
   